
package com.jstarcraft.ai.jsat.datatransform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.CategoricalData;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.linear.DenseMatrix;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.Matrix;
import com.jstarcraft.ai.jsat.linear.Vec;

/**
 * Principal Component Analysis is a method that attempts to create a basis of
 * the given space that maintains the variance in the data set while eliminating
 * correlation of the variables. <br>
 * When a full basis is formed, the dimensionality will remain the same, but the
 * data will be transformed to a new space. <br>
 * PCA is particularly useful when a small number of basis vectors can explain
 * most of the variance in the data set that is not related to noise,
 * maintaining information while reducing the dimensionality of the data. <br>
 * <br>
 * PCA works only on the numerical attributes of a data set. <br>
 * For PCA to work correctly, a {@link ZeroMeanTransform} should be applied to
 * the data set first. If not done, the first dimension of PCA may contain noise
 * and become uninformative, possibly throwing off the computation of the other
 * PCs. <br>
 * <br>
 * The components are computed with the Nonlinear Iterative PArtial Least
 * Squares (NIPALS) algorithm: one component at a time is extracted by power
 * iteration on the residual data matrix, which is then deflated by that
 * component before the next one is sought.
 * 
 * @author Edward Raff
 * @see ZeroMeanTransform
 */
public class PCA implements DataTransform {

    private static final long serialVersionUID = 8736609877239941617L;

    /**
     * Upper bound on the NIPALS power iterations spent on a single component.
     * If convergence has not been reached by then, the current estimate is kept.
     */
    private static final int MAX_NIPALS_ITERATIONS = 100;

    /**
     * Matrix whose rows are the learned principal component loading vectors
     * (the transpose of the loading matrix). {@code null} until {@link #fit}
     * has been called.
     */
    private Matrix P;
    /** Upper bound on the number of principal components to learn. */
    private int maxPCs;
    /** Relative convergence tolerance for the NIPALS iteration. */
    private double threshold;

    /**
     * Creates a new object for performing PCA that stops at 50 principal
     * components. This may not be optimal for any particular dataset
     *
     */
    public PCA() {
        this(50);
    }

    /**
     * Performs PCA analysis using the given data set, so that transformations may
     * be performed on future data points. <br>
     * <br>
     * NOTE: The maximum number of PCs will be learned until a convergence threshold
     * is met. It is possible that the number of PCs computed will be equal to the
     * number of dimensions, meaning no dimensionality reduction has occurred, but a
     * transformation of the dimensions into a new space.
     * 
     * @param dataSet the data set to learn from
     */
    public PCA(DataSet dataSet) {
        this(dataSet, Integer.MAX_VALUE);
    }

    /**
     * Performs PCA analysis using the given data set, so that transformations may
     * be performed on future data points.
     * 
     * @param dataSet the data set to learn from
     * @param maxPCs  the maximum number of Principal Components to let the
     *                algorithm learn. The algorithm may stop earlier if all the
     *                variance has been explained, or the convergence threshold has
     *                been met. Note, the computable maximum number of PCs is
     *                limited to the minimum of the number of samples and the number
     *                of dimensions.
     */
    public PCA(DataSet dataSet, int maxPCs) {
        this(dataSet, maxPCs, 1e-4);
    }

    /**
     * Creates a new object for performing PCA
     *
     * @param maxPCs the maximum number of Principal Components to let the algorithm
     *               learn. The algorithm may stop earlier if all the variance has
     *               been explained, or the convergence threshold has been met.
     *               Note, the computable maximum number of PCs is limited to the
     *               minimum of the number of samples and the number of dimensions.
     */
    public PCA(int maxPCs) {
        this(maxPCs, 1e-4);
    }

    /**
     * Creates a new object for performing PCA
     * 
     * @param maxPCs    the maximum number of Principal Components to let the
     *                  algorithm learn. The algorithm may stop earlier if all the
     *                  variance has been explained, or the convergence threshold
     *                  has been met. Note, the computable maximum number of PCs is
     *                  limited to the minimum of the number of samples and the
     *                  number of dimensions.
     * @param threshold a convergence threshold, any small value will work. Smaller
     *                  values will not produce more accurate results, but may make
     *                  the algorithm take longer if it would have terminated before
     *                  <tt>maxPCs</tt> was reached.
     */
    public PCA(int maxPCs, double threshold) {
        setMaxPCs(maxPCs);
        setThreshold(threshold);
    }

    /**
     * Performs PCA analysis using the given data set, so that transformations may
     * be performed on future data points.
     * 
     * @param dataSet   the data set to learn from
     * @param maxPCs    the maximum number of Principal Components to let the
     *                  algorithm learn. The algorithm may stop earlier if all the
     *                  variance has been explained, or the convergence threshold
     *                  has been met. Note, the computable maximum number of PCs is
     *                  limited to the minimum of the number of samples and the
     *                  number of dimensions.
     * @param threshold a convergence threshold, any small value will work. Smaller
     *                  values will not produce more accurate results, but may make
     *                  the algorithm take longer if it would have terminated before
     *                  <tt>maxPCs</tt> was reached.
     */
    public PCA(DataSet dataSet, int maxPCs, double threshold) {
        this(maxPCs, threshold);
        fit(dataSet);
    }

    /**
     * Learns the principal components of the numerical attributes of the given
     * data set using NIPALS.
     * <p>
     * At most {@code min(dataSet.size(), numNumericalVars, maxPCs)} components
     * are learned. Learning stops early once the residual matrix is exactly
     * zero, i.e. all of the variance has been explained.
     *
     * @param dataSet the data set to learn from
     * @throws ArithmeticException if the data matrix has no non-zero column
     */
    @Override
    public void fit(DataSet dataSet) {
        // Residual matrix E: the part of the data not yet explained by the
        // components found so far. E(0) is the data matrix itself.
        Matrix E = dataSet.getDataMatrix();

        // No more than min(#samples, #numeric features) components can exist
        int numComponents = Math.min(Math.min(dataSet.size(), dataSet.getNumNumericalVars()), maxPCs);

        List<Vec> loadings = new ArrayList<Vec>();
        Vec p = new DenseVector(E.cols());
        // Seed score vector for the first component (throws if the data is all zero)
        Vec t = getColumn(E);

        for (int i = 0; i < numComponents; i++) {
            if (i > 0) {
                // Re-seed from the deflated residual. The previous score vector is
                // (nearly) orthogonal to what remains of E, so reusing it makes E'*t
                // zero or pure round-off, which produced NaN loadings.
                t = firstNonZeroColumn(E);
                if (t == null)
                    break; // residual is exactly zero: all variance has been explained
            }

            double tauOld = t.dot(t);
            for (int iter = 0; iter < MAX_NIPALS_ITERATIONS; iter++) {
                // 1. Project E onto t to find the corresponding loading p
                // p = (E' * t) / (t' * t)
                p.zeroOut();
                E.transposeMultiply(1.0, t, p);
                p.mutableDivide(tauOld);

                // 2. Normalise loading vector p to length 1
                // p = p * (p' * p)^-0.5
                p.mutableMultiply(Math.pow(p.dot(p), -0.5));

                // 3. Project E onto p to find the corresponding score vector t
                // t = (E * p) / (p' * p)
                t = E.multiply(p);
                t.mutableDivide(p.dot(p));

                // 4. Check relative convergence of the score's squared norm; at
                // least two successive estimates are compared before stopping
                double tauNew = t.dot(t);
                boolean converged = iter > 0 && Math.abs(tauNew - tauOld) <= threshold * tauNew;
                tauOld = tauNew;
                if (converged)
                    break;
            }
            loadings.add(new DenseVector(p));

            // 5. Remove the estimated component from the residual: E = E - t * p'
            Matrix.OuterProductUpdate(E, t, p, -1.0);
        }

        int numCols = E.cols();
        P = new DenseMatrix(loadings.size(), numCols);
        for (int i = 0; i < loadings.size(); i++) {
            Vec loading = loadings.get(i);
            for (int j = 0; j < numCols; j++)
                P.set(i, j, loading.get(j));
        }
    }

    /**
     * Copy constructor
     * 
     * @param other the transform to copy
     */
    private PCA(PCA other) {
        if (other.P != null)
            this.P = other.P.clone();
        this.maxPCs = other.maxPCs;
        this.threshold = other.threshold;
    }

    /**
     * sets the maximum number of principal components to learn
     * 
     * @param maxPCs the maximum number of principal components to learn
     * @throws IllegalArgumentException if {@code maxPCs} is not positive
     */
    public void setMaxPCs(int maxPCs) {
        if (maxPCs <= 0)
            throw new IllegalArgumentException("number of principal components must be a positive number, not " + maxPCs);
        this.maxPCs = maxPCs;
    }

    /**
     * 
     * @return maximum number of principal components to learn
     */
    public int getMaxPCs() {
        return maxPCs;
    }

    /**
     * 
     * @param threshold the threshold for convergence of the algorithm
     * @throws IllegalArgumentException if {@code threshold} is not a finite
     *                                  positive number
     */
    public void setThreshold(double threshold) {
        if (threshold <= 0 || Double.isInfinite(threshold) || Double.isNaN(threshold))
            throw new IllegalArgumentException("threshold must be in the range (0, Inf), not " + threshold);
        this.threshold = threshold;
    }

    /**
     * 
     * @return the threshold for convergence of the algorithm
     */
    public double getThreshold() {
        return threshold;
    }

    /**
     * Returns the first non zero column
     * 
     * @param x the matrix to get a column from
     * @return the first non zero column
     * @throws ArithmeticException if every column of {@code x} is zero
     */
    private static Vec getColumn(Matrix x) {
        Vec column = firstNonZeroColumn(x);
        if (column == null)
            throw new ArithmeticException("Matrix is essentially zero");
        return column;
    }

    /**
     * Finds the first column of a matrix with a strictly positive squared norm.
     * 
     * @param x the matrix to search
     * @return that column, or {@code null} if every column is zero
     */
    private static Vec firstNonZeroColumn(Matrix x) {
        for (int j = 0; j < x.cols(); j++) {
            Vec column = x.getColumn(j);
            if (column.dot(column) > 0)
                return column;
        }
        return null;
    }

    /**
     * Projects the numerical values of a data point onto the learned principal
     * components. Categorical values and their metadata are copied unchanged.
     *
     * @param dp the data point to transform
     * @return a new data point holding the projected numerical values
     * @throws IllegalStateException if {@link #fit} has not been called yet
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        if (P == null)
            throw new IllegalStateException("PCA has not been fit to a data set yet");
        Vec projected = P.multiply(dp.getNumericalValues());
        int[] categoricalValues = Arrays.copyOf(dp.getCategoricalValues(), dp.numCategoricalValues());
        return new DataPoint(projected, categoricalValues, CategoricalData.copyOf(dp.getCategoricalData()));
    }

    @Override
    public DataTransform clone() {
        return new PCA(this);
    }

}
